21. Transformer 模型训练和评估#

21.1. 介绍#

前面的章节中,我们介绍了 Transformer 模型的基本结构和工作原理,并实现一个完整的基于 Transformer 模型的加法计算模型。在这一章节中,我们将重点关注 Transformer 模型的训练和评估过程。

21.2. 环境配置#

21.2.1. 安装依赖#

!pip install --upgrade dsxllm

21.2.2. 环境版本#

from dsxllm.util import show_version

show_version()
/Users/kong/opt/anaconda3/envs/dsx-ai/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
本书愿景:
+------+--------------------------------------------------------+
| Info |                  《动手学大语言模型》                  |
+------+--------------------------------------------------------+
| 作者 |                       吾辈亦有感                       |
| 哔站 |      https://space.bilibili.com/3546632320715420       |
| 定位 | 基于'从零构建'的理念,用实战帮助程序员快速入门大模型。 |
| 愿景 | 若让你的AI学习之路走的更容易一点,我将倍感荣幸!祝好😄 |
+------+--------------------------------------------------------+
环境信息:
+-------------+--------------+------------------------+
| Python 版本 | PyTorch 版本 | PyTorch Lightning 版本 |
+-------------+--------------+------------------------+
|   3.12.12   |    2.10.0    |         2.6.1          |
+-------------+--------------+------------------------+

21.3. 初始化模型和训练器#

from dsxllm.transformer.tokenizer import get_tokenizer
from dsxllm.transformer.dataset import TextTransform, TextDataModule
from dsxllm.transformer.model import Transformer

import lightning as L

# 超参配置
encoder_max_length = 7  # 编码器输入最大长度
decoder_max_length = 6  # 解码器输入最大长度

batch_size = 100  # 批次大小
d_model = 128  # 模型维度
feedforward_size = 512  # 前馈神经网络维度
n_layers = 4  # 编码器和解码器层数
learning_rate = 0.0001  # 学习率


# 1️⃣ 初始化分词器
tokenizer = get_tokenizer()

# 2️⃣ 初始化编码器和解码器的数据转换器
encoder_transform = TextTransform(tokenizer, max_length=encoder_max_length)
decoder_transform = TextTransform(tokenizer, max_length=decoder_max_length)

# 3️⃣ 加载数据模组
datamodule = TextDataModule(
    batch_size=batch_size,
    encoder_transform=encoder_transform,
    decoder_transform=decoder_transform,
    train_data_file="./dataset/addition_train.txt",
    val_data_file="./dataset/addition_val.txt",
)

# 4️⃣ 初始化模型
model = Transformer(
    tokenizer,
    d_model,
    feedforward_size,
    n_layers=n_layers,
    learning_rate=learning_rate,
    encoder_max_length=encoder_max_length,
    decoder_max_length=decoder_max_length - 1,
)

# 5️⃣ 初始化训练器
trainer = L.Trainer(
    max_epochs=12,  # 最大训练轮数
    log_every_n_steps=3,  # 每 3 个批次打印一次日志
    check_val_every_n_epoch=1,  # 每轮训练验证一次
    num_sanity_val_steps=0,  # 训练前不进行验证
    enable_progress_bar=False,  # 不显示进度条
)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
/Users/kong/opt/anaconda3/envs/dsx-ai/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.

21.4. 训练前评估#

训练前评估为模型性能建立初始性能基准。

# 直接调用验证函数进行评估
trainer.validate(model=model, datamodule=datamodule)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
/Users/kong/opt/anaconda3/envs/dsx-ai/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
***** Validation: 样本总数 5000  正确预测: 0  正确率: 0.0000 *****
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric             DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     correct_sequences                 0.0            │
│      correct_tokens                  705.0           │
│          seq_acc                      0.0            │
│         token_acc             0.03403495252132416    │
│      total_sequences                5000.0           │
│       total_tokens                  20714.0          │
└───────────────────────────┴───────────────────────────┘
[{'total_sequences': 5000.0,
  'correct_sequences': 0.0,
  'seq_acc': 0.0,
  'total_tokens': 20714.0,
  'correct_tokens': 705.0,
  'token_acc': 0.03403495252132416}]

21.5. 训练模型#

调用 trainer.fit() 训练 12 个轮次。

model.clear_cache()
trainer.fit(model=model, datamodule=datamodule)
┏━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃    Name     Type     Params  Mode    FLOPs                       In sizes    Out sizes ┃
┡━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ 0 │ encoder │ Encoder │  987 K │ train │ 27.7 M │                        [2, 7]  [2, 7, 128] │
│ 1 │ decoder │ Decoder │  1.2 M │ train │ 24.9 M │ [[2, 5], [2, 7], [2, 7, 128]]   [2, 5, 16] │
└───┴─────────┴─────────┴────────┴───────┴────────┴───────────────────────────────┴─────────────┘
Trainable params: 2.2 M                                                                                            
Non-trainable params: 0                                                                                            
Total params: 2.2 M                                                                                                
Total estimated model params size (MB): 8                                                                          
Modules in train mode: 125                                                                                         
Modules in eval mode: 0                                                                                            
Total FLOPs: 52.7 M                                                                                                
***** Validation: 样本总数 5000  正确预测: 76  正确率: 0.0152 *****
***** 【Epoch 0】  Train Avg Loss: 1.4259 *****
***** Validation: 样本总数 5000  正确预测: 1124  正确率: 0.2248 *****
***** 【Epoch 1】  Train Avg Loss: 1.0204 *****
***** Validation: 样本总数 5000  正确预测: 4582  正确率: 0.9164 *****
***** 【Epoch 2】  Train Avg Loss: 0.2894 *****
***** Validation: 样本总数 5000  正确预测: 4983  正确率: 0.9966 *****
***** 【Epoch 3】  Train Avg Loss: 0.0340 *****
***** Validation: 样本总数 5000  正确预测: 4999  正确率: 0.9998 *****
***** 【Epoch 4】  Train Avg Loss: 0.0043 *****
***** Validation: 样本总数 5000  正确预测: 5000  正确率: 1.0000 *****
***** 【Epoch 5】  Train Avg Loss: 0.0020 *****
***** Validation: 样本总数 5000  正确预测: 5000  正确率: 1.0000 *****
***** 【Epoch 6】  Train Avg Loss: 0.0012 *****
***** Validation: 样本总数 5000  正确预测: 5000  正确率: 1.0000 *****
***** 【Epoch 7】  Train Avg Loss: 0.0008 *****
***** Validation: 样本总数 5000  正确预测: 5000  正确率: 1.0000 *****
***** 【Epoch 8】  Train Avg Loss: 0.0006 *****
***** Validation: 样本总数 5000  正确预测: 5000  正确率: 1.0000 *****
***** 【Epoch 9】  Train Avg Loss: 0.0005 *****
***** Validation: 样本总数 5000  正确预测: 5000  正确率: 1.0000 *****
***** 【Epoch 10】  Train Avg Loss: 0.0003 *****
`Trainer.fit` stopped: `max_epochs=12` reached.
***** Validation: 样本总数 5000  正确预测: 5000  正确率: 1.0000 *****
***** 【Epoch 11】  Train Avg Loss: 0.0003 *****

21.5.1. 训练过程可视化#

绘制训练过程中损失值的变化曲线。

from dsxllm.util import plot_loss_curves

plot_loss_curves(model.train_epoch_losses)
../_images/ef17c637818476ce3096d46a1c688eed5549db75e49e8797657001e7995de7e6.png

21.5.2. 查看模型评估记录#

查看训练过程中的评估结果,观察模型在验证集上的表现。

from dsxllm.util import to_dataframe

df = to_dataframe(model.validation_epoch_outputs)

display(df)
epoch 总样本数 正确样本数 样本准确率 总Token数 正确Token数 Token准确率
0 0 5000 76 0.0152 20714 9495 0.4584
1 1 5000 1124 0.2248 20714 14741 0.7116
2 2 5000 4582 0.9164 20714 20287 0.9794
3 3 5000 4983 0.9966 20714 20697 0.9992
4 4 5000 4999 0.9998 20714 20713 1.0000
5 5 5000 5000 1.0000 20714 20714 1.0000
6 6 5000 5000 1.0000 20714 20714 1.0000
7 7 5000 5000 1.0000 20714 20714 1.0000
8 8 5000 5000 1.0000 20714 20714 1.0000
9 9 5000 5000 1.0000 20714 20714 1.0000
10 10 5000 5000 1.0000 20714 20714 1.0000
11 11 5000 5000 1.0000 20714 20714 1.0000

21.6. 训练后评估#

与训练前的评估结果对比,确认模型训练效果是否有效。

# 直接调用验证函数进行评估
trainer.validate(model=model, datamodule=datamodule)
/Users/kong/opt/anaconda3/envs/dsx-ai/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
***** Validation: 样本总数 5000  正确预测: 5000  正确率: 1.0000 *****
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric             DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│     correct_sequences               5000.0           │
│      correct_tokens                 20714.0          │
│          seq_acc                      1.0            │
│         token_acc                     1.0            │
│      total_sequences                5000.0           │
│       total_tokens                  20714.0          │
└───────────────────────────┴───────────────────────────┘
[{'total_sequences': 5000.0,
  'correct_sequences': 5000.0,
  'seq_acc': 1.0,
  'total_tokens': 20714.0,
  'correct_tokens': 20714.0,
  'token_acc': 1.0}]

21.7. 使用模型进行预测#

使用一些测试算式进行推理预测,直观观察模型的预测效果。

from dsxllm.util import print_generation_predictions

# 1️⃣ 创建一些测试问题和答案
questions = ["829+33", "58+136", "22+593", "243+269", "1+1"]
answers = ["862", "194", "615", "512", "2"]

# 2️⃣ 使用与训练时统一的数据处理方法对输入进行处理
question_encoded = encoder_transform(questions)
encoder_input_ids = question_encoded["input_ids"]

# 3️⃣ 使用模型进行预测
generated_texts = model.generate_batch(encoder_input_ids)

# 4️⃣ 输出预测结果
print_generation_predictions(questions, answers, generated_texts)
🎯 生成结果 (准确率: 5/5 = 100.00%):
+---------+--------+--------+------+
|   输入  | 真实值 | 预测值 | 标记 |
+---------+--------+--------+------+
|  829+33 |  862   |  862   |  ☑   |
|  58+136 |  194   |  194   |  ☑   |
|  22+593 |  615   |  615   |  ☑   |
| 243+269 |  512   |  512   |  ☑   |
|   1+1   |   2    |   2    |  ☑   |
+---------+--------+--------+------+

21.8. 本章小结#

我们已经完成了使用 Transformer 重构加法计算模型的工作。经过训练和评估,新模型在评估集上的准确率达到了 100%。通过这个小任务,我们亲手实现了 Transformer 的所有组件,深入透彻地掌握了其工作原理,并深刻体会到从循环神经网络到 Transformer 的革命性飞跃。掌握 Transformer,就相当于掌握了破解现代大语言模型的钥匙。从下一章开始,我们将正式开启大语言模型的新篇章。

21.9. 答疑讨论#